import soapy
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
from astropy.io import fits
import copy
import pickle

from Global_variable_setting import image_size, patch_size


# define custom exceptions

class RepeatProcessing(Exception):
    """Raised when results for the selected model already exist.

    The script refuses to re-test a model whose result directory is already
    present, so existing result data is never overwritten.
    """

    def __init__(self):
        super().__init__("This model has been tested by this program. There is no need to test again. "
                         "And this program will exit for the safety of the existing result data!")


class NoDataError(Exception):
    """Raised when the selected model directory has nothing in it to test."""

    _MESSAGE = 'There is no data for test in this directory!'

    def __init__(self):
        super().__init__(self._MESSAGE)


# helper for per-patch min-max normalization of Shack-Hartmann frames
def normalize(data, img_size=None, p_size=None):
    """Min-max normalise each square patch of a 2-D frame to [0, 1], in place.

    The frame is divided into a grid of ``img_size // p_size`` by
    ``img_size // p_size`` non-overlapping ``p_size`` x ``p_size`` patches,
    starting at index ``[0, 0]``. Any rows/columns beyond that grid are left
    untouched. Each patch is rescaled with ``(x - min) / (max - min)``.

    Patches are skipped (left unchanged) when their maximum is 0 (as before)
    or when all their values are equal. Previously a constant non-zero patch
    was divided by zero and filled with NaN.

    Args:
        data: 2-D NumPy array, modified in place. It should have a floating
            dtype; with an integer dtype the normalised values are truncated
            on assignment.
        img_size: side length of the frame in pixels. Defaults to the
            project-wide ``image_size`` from ``Global_variable_setting``.
        p_size: side length of one patch in pixels. Defaults to the
            project-wide ``patch_size`` from ``Global_variable_setting``.

    Returns:
        None; ``data`` is updated in place.
    """
    if img_size is None:
        img_size = image_size
    if p_size is None:
        p_size = patch_size
    n_patches = img_size // p_size
    for row in range(n_patches):
        rows = slice(row * p_size, (row + 1) * p_size)
        for col in range(n_patches):
            cols = slice(col * p_size, (col + 1) * p_size)
            patch = data[rows, cols]
            lo = np.min(patch)
            hi = np.max(patch)
            # hi == 0: original skip rule (kept so preprocessing matches training);
            # hi == lo: constant patch, normalising it would divide by zero.
            if hi == 0 or hi == lo:
                continue
            data[rows, cols] = (patch - lo) / (hi - lo)


# Build the soapy simulation from the data-generation configuration, then run
# its initialisation and interaction-matrix steps before any frame is evaluated.
work_dir = os.getcwd()
config_file = '{}/configuration_dir/generate_data.yaml'.format(work_dir)

sim = soapy.Sim(config_file)
sim.aoinit()        # initialise the simulation objects from the config
sim.makeIMat()      # build the interaction matrix


# Enter essential parameters for test.
# The answers select the directory the trained model is loaded from
# (model_path_inter) and the parallel directory results are written to
# (result_path_inter):
#   <root>[/good_data_for_test]/<model type>/<normalized|unnormalized>/<index>


def _ask_choice(prompt, choices):
    """Prompt until the user types one of *choices* exactly; return that answer."""
    while True:
        answer = input(prompt)
        if answer in choices:
            return answer


dataset_type = _ask_choice('Enter dataset type (normal or perfect):', ['normal', 'perfect'])
if dataset_type == 'perfect':
    model_path_inter = work_dir + '/result/training_result/good_data_for_test'
    result_path_inter = work_dir + '/result/detailed_soapy_test_result/good_data_for_test'
else:
    model_path_inter = work_dir + '/result/training_result'
    result_path_inter = work_dir + '/result/detailed_soapy_test_result'

model_type = _ask_choice('Enter model type (Vision Transformer or CNN):', ['Vision Transformer', 'CNN'])
model_path_inter += '/' + model_type
result_path_inter += '/' + model_type

normalized = _ask_choice('Enter if this model is normalized (True or False):', ['True', 'False'])
norm_subdir = '/normalized' if normalized == 'True' else '/unnormalized'
model_path_inter += norm_subdir
result_path_inter += norm_subdir

# A missing model directory is reported like an empty one instead of crashing
# with FileNotFoundError inside os.listdir.
if not os.path.isdir(model_path_inter) or not os.listdir(model_path_inter):
    raise NoDataError

# NOTE(review): model sub-directories are assumed to be named 0 .. n-1, since
# the path below is built from the typed index — confirm against the trainer.
n_models = len(os.listdir(model_path_inter))
while True:
    model_index = input(f'Enter the index of model ( in the range of [0, {n_models})):')
    # Re-prompt on non-numeric input (int() used to raise and abort the script).
    try:
        index_value = int(model_index)
    except ValueError:
        continue
    if 0 <= index_value < n_models:
        # Canonicalise so inputs like '01', ' 1' or '-0' map to the real directory name.
        model_index = str(index_value)
        break
model_path = model_path_inter + '/' + model_index + '/model'
result_path = result_path_inter + '/' + model_index


# Refuse to test the same model twice so that existing results are never overwritten.
if os.path.exists(result_path):
    raise RepeatProcessing

# Load the model BEFORE creating the result directory: if loading fails, no empty
# result directory is left behind that would make every later run raise
# RepeatProcessing even though nothing was tested.
model = tf.saved_model.load(model_path)

# make result path
os.makedirs(result_path)


# start evaluating

# --- test different star magnitudes first --------------------------------------
# For each star magnitude 8..18 and each of its 1000 prepared cases:
#   * evaluate the science image, instantaneous Strehl and wavefront error before
#     correction, then again after applying the DM shape predicted by the model
#     from the Shack-Hartmann frame;
#   * save a comparison figure, the arrays and the metrics for the case.
# Afterwards a histogram summary is saved per magnitude.
test_diff_mag_data_path = work_dir + '/testing_data/different_mag'
diff_mag_result_path = result_path + '/different_mag'
os.makedirs(diff_mag_result_path)
diff_mag_overall_direc = {}                  # 'star_mag=<m>' -> per-magnitude metric lists

# (metric key, x label, title prefix) for the four per-magnitude histograms
mag_hist_panels = [
    ('wfe_before_with_nm', 'wavefront error (nm)', 'wfe before correction with nm histogram'),
    ('wfe_after_with_nm', 'wavefront error (nm)', 'wfe after correction with nm histogram'),
    ('inst_strehl_before', 'inst strehl before correction', 'inst strehl before correction'),
    ('inst_strehl_after', 'inst strehl after correction', 'inst strehl after correction'),
]

for star_mag in range(8, 19):
    specific_mag_path = diff_mag_result_path + '/star_mag=' + str(star_mag)
    os.makedirs(specific_mag_path)
    # one list entry per test case for each metric
    specific_mag_overall_direc = {'wfe_before_with_nm': [], 'wfe_after_with_nm': [],
                                  'inst_strehl_before': [], 'inst_strehl_after': []}
    for case in range(1000):
        # prepare data paths
        specific_test_result_path = specific_mag_path + '/' + str(case)
        os.makedirs(specific_test_result_path)
        specific_test_data_path = test_diff_mag_data_path + '/star_mag=' + str(star_mag) + '/' + str(case)

        # --- before correction ---
        # fits.getdata closes the file after reading; fits.open(...)[0].data never
        # closed its HDUList.
        scrns = fits.getdata(specific_test_data_path + '/atmos_scrns.fits', ext=0).astype('float64')
        phase_before_radius = fits.getdata(specific_test_data_path + '/phase_before_radius.fits', ext=0)
        # Results are deep-copied so later simulator calls cannot alter them
        # (presumably soapy reuses its output buffers between calls).
        image_before = copy.deepcopy(sim.sciCams[0].frame(scrns))
        inst_strehl_before = copy.deepcopy(sim.sciCams[0].instStrehl)
        wfe_before_with_nm = copy.deepcopy(sim.sciCams[0].calc_wavefronterror())

        # --- after correction: SH frame -> model -> DM command -> DM shape ---
        detector_file = '/distorted_detector.fits' if dataset_type == 'normal' else '/perfect_detector.fits'
        sh_frame = fits.getdata(specific_test_data_path + detector_file, ext=0).astype('float32')
        if normalized == 'True':
            normalize(sh_frame)
        sh_frame = sh_frame[np.newaxis, ..., np.newaxis]     # add batch and channel axes
        dm_command = model(sh_frame)[0]                      # first (only) batch element
        dm_shape = sim.dms[0].dmFrame(dm_command)[np.newaxis, ...]
        image_after = copy.deepcopy(sim.sciCams[0].frame(scrns, dm_shape))
        wfe_after_with_nm = copy.deepcopy(sim.sciCams[0].calc_wavefronterror())
        phase_after_nm = copy.deepcopy(sim.wfss[0].los.frame(scrns, dm_shape))
        # masked with sim.mask and converted with the line-of-sight phs2Rad factor
        phase_after_radius = phase_after_nm * sim.mask * sim.wfss[0].los.phs2Rad
        inst_strehl_after = copy.deepcopy(sim.sciCams[0].instStrehl)

        # --- 2x2 comparison figure; residual phase shares the input phase's colour range ---
        plt.figure()
        plt.subplot(2, 2, 1)
        plt.imshow(phase_before_radius, origin='lower')
        cbar = plt.colorbar()
        plt.title('phase_before_radius' + '\nwfe(nm)=' + str(wfe_before_with_nm))
        plt.subplot(2, 2, 2)
        plt.imshow(phase_after_radius, origin='lower', vmin=cbar.vmin, vmax=cbar.vmax)
        plt.colorbar()
        plt.title('phase_after_radius' + '\nwfe(nm)=' + str(wfe_after_with_nm))
        plt.subplot(2, 2, 3)
        plt.imshow(image_before, origin='lower')
        plt.colorbar()
        plt.title('image_before' + '\ninst_strehl=' + str(inst_strehl_before))
        plt.subplot(2, 2, 4)
        plt.imshow(image_after, origin='lower')
        plt.colorbar()
        plt.title('image_after' + '\ninst_strehl=' + str(inst_strehl_after))
        plt.savefig(specific_test_result_path + '/correction_comparative_result')
        plt.close('all')

        # --- save per-case arrays and metrics ---
        # fix_imports is ignored by np.save (and deprecated), so it is not passed.
        for array_name, array in (('phase_before_radius', phase_before_radius),
                                  ('phase_after_radius', phase_after_radius),
                                  ('image_before', image_before),
                                  ('image_after', image_after)):
            np.save(specific_test_result_path + '/' + array_name, array, allow_pickle=False)
        correction_comparative_directory = {'wfe_before_with_nm': wfe_before_with_nm,
                                            'wfe_after_with_nm': wfe_after_with_nm,
                                            'inst_strehl_before': inst_strehl_before,
                                            'inst_strehl_after': inst_strehl_after}
        # NOTE: despite the .txt suffix this is a binary pickle; the name is kept
        # so existing readers of these results keep working.
        with open(specific_test_result_path + '/correction_comparative_directory.txt', 'wb') as fh:
            pickle.dump(correction_comparative_directory, fh)

        for metric_key, metric_value in correction_comparative_directory.items():
            specific_mag_overall_direc[metric_key].append(metric_value)
        # report progress
        print(f'star_magnitude={star_mag}, the result of the {case} is wfe_after_with_nm={wfe_after_with_nm}, inst_strehl_after={inst_strehl_after}')

    # --- per-magnitude statistical summary ---
    specific_mag_overall_result_path = specific_mag_path + '/overall_result'
    os.makedirs(specific_mag_overall_result_path)
    plt.figure()
    for panel, (metric_key, x_label, title_prefix) in enumerate(mag_hist_panels, start=1):
        values = specific_mag_overall_direc[metric_key]
        plt.subplot(2, 2, panel)
        plt.hist(values)
        plt.xlabel(x_label)
        plt.ylabel('count')
        plt.title(f'{title_prefix}\nmean+-std={np.mean(values)}+-{np.std(values)}')
    plt.tight_layout()
    plt.savefig(specific_mag_overall_result_path + '/overall_result')
    plt.close('all')

    with open(specific_mag_overall_result_path + '/overall_result_direc', 'wb') as fh:
        pickle.dump(specific_mag_overall_direc, fh)

    # collect this magnitude's lists for the cross-magnitude summary
    diff_mag_overall_direc['star_mag=' + str(star_mag)] = specific_mag_overall_direc

# --- summary across all star magnitudes -----------------------------------------
diff_mag_overall_result_path = diff_mag_result_path + '/overall_result'
os.makedirs(diff_mag_overall_result_path)
metrics_list = ['wfe_before_with_nm', 'wfe_after_with_nm', 'inst_strehl_before', 'inst_strehl_after']
star_magnitudes = list(range(8, 19))

# descriptive statistics for every (magnitude, metric) pair
descriptive_overall_direc_diff_mag = {}
for mag in star_magnitudes:
    mag_key = 'star_mag=' + str(mag)
    per_metric = {}
    for metric in metrics_list:
        samples = diff_mag_overall_direc[mag_key][metric]
        per_metric[metric] = {'std': np.std(samples),
                              'mean': np.mean(samples),
                              'median': np.median(samples),
                              'max': np.max(samples),
                              'min': np.min(samples),
                              'P25': np.percentile(samples, 25),
                              'P75': np.percentile(samples, 75),
                              }
    descriptive_overall_direc_diff_mag[mag_key] = per_metric

# mean +- std of each metric as a function of star magnitude
plt.figure()
for panel, metric in enumerate(metrics_list, start=1):
    stats = [descriptive_overall_direc_diff_mag['star_mag=' + str(mag)][metric] for mag in star_magnitudes]
    plt.subplot(2, 2, panel)
    plt.errorbar(star_magnitudes, [s['mean'] for s in stats], yerr=[s['std'] for s in stats])
    plt.xlabel('star magnitude')
    plt.ylabel('wave front error (nm)' if panel <= 2 else 'inst strehl')
    plt.title(metric + ' mean+-std')
plt.tight_layout()
plt.savefig(diff_mag_overall_result_path + '/overall_plot_for_mean+-std')
plt.show()      # on interactive backends this blocks until the window is closed
plt.close('all')

# save raw lists and descriptive statistics, both pickled and as readable text
with open(diff_mag_overall_result_path + '/overall_direc', 'wb') as fh:
    pickle.dump(diff_mag_overall_direc, fh)
with open(diff_mag_overall_result_path + '/overall_direc.txt', 'w') as fh:
    fh.write(f'diff_mag_overall_direc={diff_mag_overall_direc}')
with open(diff_mag_overall_result_path + '/descriptive_overall_direc_diff_mag', 'wb') as fh:
    pickle.dump(descriptive_overall_direc_diff_mag, fh)
with open(diff_mag_overall_result_path + '/descriptive_overall_direc_diff_mag.txt', 'w') as fh:
    fh.write(f'descriptive_overall_direc_diff_mag={descriptive_overall_direc_diff_mag}')


# --- test different r0 secondly --------------------------------------------------
# Same per-case procedure as the star-magnitude sweep, iterated over Fried
# parameter values instead. Data/result directories are keyed 'r0=' + str(r0),
# so 0.10 and 0.20 map to 'r0=0.1' and 'r0=0.2'.
test_diff_r0_data_path = work_dir + '/testing_data/different_r0'
diff_r0_result_path = result_path + '/different_r0'
os.makedirs(diff_r0_result_path)
diff_r0_overall_direc = {}                  # 'r0=<value>' -> per-r0 metric lists

r0_values = [0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20]
# (metric key, x label, title prefix) for the four per-r0 histograms
r0_hist_panels = [
    ('wfe_before_with_nm', 'wavefront error (nm)', 'wfe before correction with nm histogram'),
    ('wfe_after_with_nm', 'wavefront error (nm)', 'wfe after correction with nm histogram'),
    ('inst_strehl_before', 'inst strehl before correction', 'inst strehl before correction'),
    ('inst_strehl_after', 'inst strehl after correction', 'inst strehl after correction'),
]

for r0 in r0_values:
    specific_r0_path = diff_r0_result_path + '/r0=' + str(r0)
    os.makedirs(specific_r0_path)
    # one list entry per test case for each metric
    specific_r0_overall_direc = {'wfe_before_with_nm': [],
                                 'wfe_after_with_nm': [],
                                 'inst_strehl_before': [],
                                 'inst_strehl_after': [],
                                 }
    for case in range(1000):
        # prepare data paths
        specific_test_result_path = specific_r0_path + '/' + str(case)
        os.makedirs(specific_test_result_path)
        specific_test_data_path = test_diff_r0_data_path + '/r0=' + str(r0) + '/' + str(case)

        # --- before correction ---
        # fits.getdata closes the file after reading; fits.open(...)[0].data never
        # closed its HDUList.
        scrns = fits.getdata(specific_test_data_path + '/atmos_scrns.fits', ext=0).astype('float64')
        phase_before_radius = fits.getdata(specific_test_data_path + '/phase_before_radius.fits', ext=0)
        # Results are deep-copied so later simulator calls cannot alter them
        # (presumably soapy reuses its output buffers between calls).
        image_before = copy.deepcopy(sim.sciCams[0].frame(scrns))
        inst_strehl_before = copy.deepcopy(sim.sciCams[0].instStrehl)
        wfe_before_with_nm = copy.deepcopy(sim.sciCams[0].calc_wavefronterror())

        # --- after correction: SH frame -> model -> DM command -> DM shape ---
        detector_file = '/distorted_detector.fits' if dataset_type == 'normal' else '/perfect_detector.fits'
        sh_frame = fits.getdata(specific_test_data_path + detector_file, ext=0).astype('float32')
        if normalized == 'True':
            normalize(sh_frame)
        sh_frame = sh_frame[np.newaxis, ..., np.newaxis]     # add batch and channel axes
        dm_command = model(sh_frame)[0]                      # first (only) batch element
        dm_shape = sim.dms[0].dmFrame(dm_command)[np.newaxis, ...]
        image_after = copy.deepcopy(sim.sciCams[0].frame(scrns, dm_shape))
        wfe_after_with_nm = copy.deepcopy(sim.sciCams[0].calc_wavefronterror())
        phase_after_nm = copy.deepcopy(sim.wfss[0].los.frame(scrns, dm_shape))
        # masked with sim.mask and converted with the line-of-sight phs2Rad factor
        phase_after_radius = phase_after_nm * sim.mask * sim.wfss[0].los.phs2Rad
        inst_strehl_after = copy.deepcopy(sim.sciCams[0].instStrehl)

        # --- 2x2 comparison figure; residual phase shares the input phase's colour range ---
        # Image panels now use origin='lower' like every other panel (and like the
        # star-magnitude sweep), so the two sweeps' figures have the same orientation.
        plt.figure()
        plt.subplot(2, 2, 1)
        plt.imshow(phase_before_radius, origin='lower')
        cbar = plt.colorbar()
        plt.title('phase_before_radius' + '\nwfe(nm)=' + str(wfe_before_with_nm))
        plt.subplot(2, 2, 2)
        plt.imshow(phase_after_radius, origin='lower', vmin=cbar.vmin, vmax=cbar.vmax)
        plt.colorbar()
        plt.title('phase_after_radius' + '\nwfe(nm)=' + str(wfe_after_with_nm))
        plt.subplot(2, 2, 3)
        plt.imshow(image_before, origin='lower')
        plt.colorbar()
        plt.title('image_before' + '\ninst_strehl=' + str(inst_strehl_before))
        plt.subplot(2, 2, 4)
        plt.imshow(image_after, origin='lower')
        plt.colorbar()
        plt.title('image_after' + '\ninst_strehl=' + str(inst_strehl_after))
        plt.savefig(specific_test_result_path + '/correction_comparative_result')
        plt.close('all')

        # --- save per-case arrays and metrics ---
        # fix_imports is ignored by np.save (and deprecated), so it is not passed.
        for array_name, array in (('phase_before_radius', phase_before_radius),
                                  ('phase_after_radius', phase_after_radius),
                                  ('image_before', image_before),
                                  ('image_after', image_after)):
            np.save(specific_test_result_path + '/' + array_name, array, allow_pickle=False)
        correction_comparative_directory = {'wfe_before_with_nm': wfe_before_with_nm,
                                            'wfe_after_with_nm': wfe_after_with_nm,
                                            'inst_strehl_before': inst_strehl_before,
                                            'inst_strehl_after': inst_strehl_after}
        # NOTE: despite the .txt suffix this is a binary pickle; the name is kept
        # so existing readers of these results keep working.
        with open(specific_test_result_path + '/correction_comparative_directory.txt', 'wb') as fh:
            pickle.dump(correction_comparative_directory, fh)

        for metric_key, metric_value in correction_comparative_directory.items():
            specific_r0_overall_direc[metric_key].append(metric_value)
        # report progress
        print(f'r0={r0}, the result of the {case} is wfe_after_with_nm={wfe_after_with_nm}, inst_strehl_after={inst_strehl_after}')

    # --- per-r0 statistical summary ---
    specific_r0_overall_result_path = specific_r0_path + '/overall_result'
    os.makedirs(specific_r0_overall_result_path)
    plt.figure()
    for panel, (metric_key, x_label, title_prefix) in enumerate(r0_hist_panels, start=1):
        values = specific_r0_overall_direc[metric_key]
        plt.subplot(2, 2, panel)
        plt.hist(values)
        plt.xlabel(x_label)
        plt.ylabel('count')
        plt.title(f'{title_prefix}\nmean+-std={np.mean(values)}+-{np.std(values)}')
    plt.tight_layout()
    plt.savefig(specific_r0_overall_result_path + '/overall_result')
    plt.close('all')

    with open(specific_r0_overall_result_path + '/overall_result_direc', 'wb') as fh:
        pickle.dump(specific_r0_overall_direc, fh)

    # collect this r0's lists for the cross-r0 summary
    diff_r0_overall_direc['r0=' + str(r0)] = specific_r0_overall_direc

# --- summary across all r0 values -------------------------------------------------
diff_r0_overall_result_path = diff_r0_result_path + '/overall_result'
os.makedirs(diff_r0_overall_result_path)
metrics_list = ['wfe_before_with_nm', 'wfe_after_with_nm', 'inst_strehl_before', 'inst_strehl_after']
# must match the values (and hence the 'r0=' + str(r0) keys) used by the r0 sweep
r0_values = [0.05, 0.06, 0.07, 0.08, 0.09, 0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20]

# descriptive statistics for every (r0, metric) pair
descriptive_overall_direc_diff_r0 = {}
for r0 in r0_values:
    r0_key = 'r0=' + str(r0)
    per_metric = {}
    for metric in metrics_list:
        samples = diff_r0_overall_direc[r0_key][metric]
        per_metric[metric] = {'std': np.std(samples),
                              'mean': np.mean(samples),
                              'median': np.median(samples),
                              'max': np.max(samples),
                              'min': np.min(samples),
                              'P25': np.percentile(samples, 25),
                              'P75': np.percentile(samples, 75),
                              }
    descriptive_overall_direc_diff_r0[r0_key] = per_metric

# mean +- std of each metric as a function of r0
plt.figure()
for panel, metric in enumerate(metrics_list, start=1):
    stats = [descriptive_overall_direc_diff_r0['r0=' + str(r0)][metric] for r0 in r0_values]
    plt.subplot(2, 2, panel)
    plt.errorbar(r0_values, [s['mean'] for s in stats], yerr=[s['std'] for s in stats])
    # soapy specifies r0 in metres; 0.05-0.20 are metre values (the label used to say cm).
    plt.xlabel('atmosphere r0 (m)')
    plt.ylabel('wave front error (nm)' if panel <= 2 else 'inst strehl')
    plt.title(metric + ' mean+-std')
plt.tight_layout()
plt.savefig(diff_r0_overall_result_path + '/overall_plot_for_mean+-std')
plt.show()      # on interactive backends this blocks until the window is closed
plt.close('all')

# save raw lists and descriptive statistics, both pickled and as readable text
with open(diff_r0_overall_result_path + '/overall_direc', 'wb') as fh:
    pickle.dump(diff_r0_overall_direc, fh)
with open(diff_r0_overall_result_path + '/overall_direc.txt', 'w') as fh:
    fh.write(f'diff_r0_overall_direc={diff_r0_overall_direc}')
with open(diff_r0_overall_result_path + '/descriptive_overall_direc_diff_r0', 'wb') as fh:
    pickle.dump(descriptive_overall_direc_diff_r0, fh)
with open(diff_r0_overall_result_path + '/descriptive_overall_direc_diff_r0.txt', 'w') as fh:
    fh.write(f'descriptive_overall_direc_diff_r0={descriptive_overall_direc_diff_r0}')



